import matplotlib.pyplot as plt
import os
from utils import get_accuracy, get_samples, train_out_proj_fast, train_out_proj_closed_form
from main import Args
from data import load_SHD
from model import EchoSpike, simple_out
import numpy as np
from data import augment_shd
import torch
import seaborn as sns
from scipy.signal import savgol_filter
from tqdm.notebook import trange
from matplotlib import pyplot
pyplot.rcParams['figure.dpi'] = 600
import pickle
# Reproducibility and plotting/runtime configuration.
torch.manual_seed(0)
color_list = sns.color_palette('muted')

device = 'cpu'
batch_size = 64

# Locate the trained checkpoint and its pickled training arguments
# (the args pickle lives next to the checkpoint, same stem).
folder = 'models/'
model_name = folder + 'shd_1layer_large.pt'
args_path = model_name[:-3] + '_args.pkl'
with open(args_path, 'rb') as f:
    args = pickle.load(f)
# args = Args()
online = args.online
print(vars(args))
/home/lars/miniconda3/lib/python3.9/site-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning(). warnings.warn(_BETA_TRANSFORMS_WARNING) /home/lars/miniconda3/lib/python3.9/site-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning(). warnings.warn(_BETA_TRANSFORMS_WARNING)
{'model_name': 'shd_1layer_large', 'dataset': 'shd', 'online': True, 'device': 'cuda', 'recurrency_type': 'none', 'lr': 0.0001, 'epochs': 1000, 'augment': True, 'batch_size': 128, 'n_hidden': [1332], 'inp_thr': 0.05, 'c_y': [1.5, -1.5], 'n_inputs': 700, 'n_outputs': 20, 'n_time_bins': 100, 'beta': 0.95}
Spiking Heidelberg Digits
#train_loader, test_loader = load_PMNIST(n_time_bins, scale=0.9, patches=True) #load_NMNIST(n_time_bins, batch_size=batch_size)
n_time_bins = 100
train_loader, test_loader = load_SHD(batch_size=batch_size) #load_NMNIST(n_time_bins, batch_size=batch_size)

# Plot Example(s): show one raw SHD sample as a time x channel image.
for _ in range(1):
    frames, target = train_loader.next_item(-1, contrastive=True)
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    # frames is (time, 1, channels); drop the batch dim and transpose
    # so channels run vertically.
    plt.imshow(frames.squeeze(1).T)
    # plt.colorbar()
    print(frames.shape, target)
plt.axis('on')
/home/lars/ownCloud/ETH/Master/Project_2/SNN_CLAPP/data.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). self.y = torch.tensor(y)
torch.Size([100, 1, 700]) tensor([4.])
(-0.5, 99.5, 699.5, -0.5)
# Rebuild the EchoSpike network from the saved training arguments and
# restore its weights from the checkpoint.
SNN = EchoSpike(args.n_inputs, args.n_hidden, beta=args.beta, c_y=args.c_y,
                device=device, recurrency_type=args.recurrency_type,
                online=args.online).to(device)
SNN.load_state_dict(torch.load(model_name, map_location=device))
# train(SNN, train_loader, args.epochs, device, args.model_name,
#       batch_size=args.batch_size, online=args.online, lr=1e-8, augment=args.augment)

# Plot the (smoothed) self-supervised training loss per layer,
# starting from `from_epoch`.
from_epoch = 0
loss_path = model_name[:-3] + '_loss_hist.pt'
start_idx = int(from_epoch * len(train_loader) / args.batch_size)
echo_train_loss = torch.load(loss_path, map_location='cpu')[start_idx:]
print(echo_train_loss.shape)

# x-axis in units of epochs; Savitzky-Golay filter smooths the noisy
# per-batch loss curve.
epochs_axis = from_epoch + (args.batch_size * np.arange(echo_train_loss.shape[0]) / len(train_loader))
for layer_idx in range(echo_train_loss.shape[-1]):
    smoothed = savgol_filter(echo_train_loss[:, layer_idx], 99, 1)
    plt.plot(epochs_axis, smoothed, color=color_list[layer_idx])
plt.legend([f'layer {i+1}' for i in range(len(SNN.layers))])
# no y ticks, because it's not really meaningful
plt.yticks([])
plt.xlabel('Epoch')
plt.ylabel('EchoSpike Loss')
torch.Size([63719, 1])
Text(0, 0.5, 'EchoSpike Loss')
# plotting adaptive threshold and update rate for an example
# Draw three samples:
#   init_echo - reference sample used only to build up initial activity,
#   sample_1  - an independent sample (contrastive case, broadcast -1),
#   sample_2  - a sample of the same class as sample_1 (predictive case, +1).
# NOTE(review): these loads were commented out (the notebook relied on a
# previously-run cell); they must execute here or label_0/sample_1/sample_2
# are undefined.
init_echo, label_0 = train_loader.next_item(-1, contrastive=True)
sample_1, label_1 = train_loader.next_item(-1, contrastive=True)
sample_2, label_2 = train_loader.next_item(label_1, contrastive=False)
print(label_0, label_1, label_2)
SNN.eval()
with torch.no_grad():
    # feed first sample to get initial activity
    for t in range(100):
        inp_activity = init_echo[t].mean(axis=-1)
        SNN(init_echo[t], torch.tensor(-1, device=device), inp_activity=inp_activity)
    SNN.reset(-1)
    # feed second sample to get the update rates and thresholds for contrastive case
    contrastive_thresholds = torch.zeros(100)
    contrastive_temp_sim = torch.zeros((len(SNN.layers), 100))
    for t in range(100):
        inp_activity = sample_1[t].mean(axis=-1)
        out_spk, mems, losses = SNN(sample_1[t], torch.tensor(-1, device=device), inp_activity=inp_activity)
        # adaptive threshold scales with input activity; c_y[1] is the
        # contrastive coefficient
        contrastive_thresholds[t] = inp_activity * args.c_y[1]
        contrastive_temp_sim[:, t] = losses
    SNN.reset(-1)
    # feed third sample to get the update rates and thresholds for predictive case
    predictive_thresholds = torch.zeros(100)
    predictive_temp_sim = torch.zeros((len(SNN.layers), 100))
    for t in range(100):
        inp_activity = sample_2[t].mean(axis=-1)
        # BUG FIX: the predictive pass must be driven by sample_2 (the
        # same-class sample); previously sample_1 was fed here while
        # inp_activity and the thresholds were computed from sample_2.
        out_spk, mems, losses = SNN(sample_2[t], torch.tensor(1, device=device), inp_activity=inp_activity)
        predictive_thresholds[t] = inp_activity * args.c_y[0]
        # losses are negated so that higher means more similar
        predictive_temp_sim[:, t] = -losses
    SNN.reset(1)
# plot thresholds, with sample as background
layer = 2  # which EchoSpike layer's similarity score to show
fig, ax = plt.subplots(figsize=(10, 5))
ax2 = ax.twinx()
# imshow in background: the contrastive sample's input spikes
ax.imshow(sample_1.squeeze(1).T, aspect='auto', cmap='Reds')
ax2.plot(-contrastive_temp_sim[layer], color='r', label='Negative Similarity Score')
ax2.plot(contrastive_thresholds, color='r', linestyle='--', label='Contrastive Threshold')
ax2.hlines(args.inp_thr*args.c_y[1], 0, 100, color='r', linestyle=':', label='Input Threshold (times c(-1))')
# highlight regions where the thresholds are crossed
# BUG FIX: np.argwhere returns rows of shape (1,); feeding those arrays to
# axvspan produced matplotlib's "ragged nested sequences" deprecation
# warning. np.flatnonzero yields plain scalar indices instead.
crossed = np.flatnonzero(np.logical_and((-contrastive_temp_sim[layer] < contrastive_thresholds).numpy(), contrastive_thresholds.numpy() < args.inp_thr*args.c_y[1]))
for t in crossed:
    ax2.axvspan(t, t+1, color='r', alpha=0.2, lw=0)
# cosmetics: hide the image's y axis, move the twin axis labels left
ax.yaxis.set_visible(False)
ax2.spines['right'].set_visible(False)
ax2.yaxis.tick_left()
ax2.yaxis.set_label_position('left')
ax2.set_xlim(ax.get_xlim())
# get rid of right margin
ax2.margins(0)
ax.set_xlabel('Timesteps')
plt.ylabel('Thresholds & Similarity Score')
plt.xlim(0, 100)
plt.legend()
# same for predictive
fig, ax = plt.subplots(figsize=(10, 5))
ax2 = ax.twinx()
# imshow in background: the predictive sample's input spikes
ax.imshow(sample_2.squeeze(1).T, aspect='auto', cmap='Blues')
ax2.plot(predictive_temp_sim[layer], color='b', label='Similarity Score')
ax2.plot(predictive_thresholds, color='b', linestyle='--', label='Predictive Threshold')
ax2.hlines(args.inp_thr*args.c_y[0], 0, 100, color='b', linestyle=':', label='Input Threshold (times c(1))')
# highlight regions where the thresholds are crossed
# BUG FIX: np.argwhere returns rows of shape (1,); feeding those arrays to
# axvspan produced matplotlib's "ragged nested sequences" deprecation
# warning. np.flatnonzero yields plain scalar indices instead.
crossed = np.flatnonzero(np.logical_and((predictive_temp_sim[layer] < predictive_thresholds).numpy(), predictive_thresholds.numpy() > args.inp_thr*args.c_y[0]))
for t in crossed:
    ax2.axvspan(t, t+1, color='b', alpha=0.1, lw=0)
# cosmetics: hide the image's y axis, move the twin axis labels left
ax.yaxis.set_visible(False)
ax2.spines['right'].set_visible(False)
ax2.yaxis.tick_left()
ax2.yaxis.set_label_position('left')
ax2.set_xlim(ax.get_xlim())
# get rid of right margin
#ax2.margins(0)
ax.set_xlabel('Timesteps')
plt.ylabel('Thresholds & Similarity Score')
plt.xlim(0, 100)
plt.legend()
plt.show()
tensor([9.]) tensor([7.]) tensor([7.])
/home/lars/miniconda3/lib/python3.9/site-packages/matplotlib/patches.py:1111: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray. xy = np.asarray(xy) /home/lars/miniconda3/lib/python3.9/site-packages/matplotlib/patches.py:1111: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray. xy = np.asarray(xy)
# Effective input->layer projections: chain each layer's feedforward
# weights back onto the raw input channels by matrix products.
layers = [SNN.layers[0].fc.weight[:, :args.n_inputs]]
for idx in range(1, len(SNN.layers)):
    prev_proj = layers[-1]
    layers.append(SNN.layers[idx].fc.weight[:, :args.n_hidden[idx - 1]] @ prev_proj)

# Raw feedforward weight matrix of every layer.
for lyr in SNN.layers:
    plt.figure()
    plt.imshow(lyr.fc.weight.detach(), cmap='viridis')
    plt.colorbar()
    # plt.figure()
    # plt.imshow(lyr.pred.weight.detach(), vmax=0.5, vmin=-0.5)
    # plt.colorbar()

# Effective projections back to input space.
for proj in layers:
    plt.figure()
    plt.imshow(proj.detach())
    plt.colorbar()
/tmp/ipykernel_19791/2587452636.py:8: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first. plt.colorbar() /tmp/ipykernel_19791/2587452636.py:15: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first. plt.colorbar()